{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reinplementation of Block Coordinate Descent (BCD) Algorithm for Training DNNs (10-layer MLP) for CIFAR-10 in PyTorch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch Version: 1.0.0\n",
      "Torchvision Version: 0.2.1\n",
      "GPU is available? True\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "5 runs, seed = 5, 10, 15, 8, 19; \n",
    "validation accuracies: \n",
    "\"\"\"\n",
    "from __future__ import print_function, division\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms, utils\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import os\n",
    "import copy\n",
    "\n",
    "print(\"PyTorch Version:\", torch.__version__)\n",
    "print(\"Torchvision Version:\", torchvision.__version__)\n",
    "print(\"GPU is available?\", torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read in MNIST dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "dtype = torch.float\n",
    "# device = torch.device(\"cpu\") # Uncomment this to run on CPU\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\") # Uncomment this to run on GPU\n",
    "\n",
    "# Convert to tensor and scale to [0, 1]\n",
    "ts = transforms.Compose([transforms.ToTensor(), \n",
    "                             transforms.Normalize((0,), (1,))])\n",
    "cifar_trainset = datasets.CIFAR10('../data', train=True, download=True, transform=ts)\n",
    "cifar_testset = datasets.CIFAR10(root='../data', train=False, download=True, transform=ts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data manipulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manipulate train set\n",
    "x_d0 = cifar_trainset[0][0].size()[0]\n",
    "x_d1 = cifar_trainset[0][0].size()[1]\n",
    "x_d2 = cifar_trainset[0][0].size()[2]\n",
    "N = x_d4 = len(cifar_trainset)\n",
    "K = 10\n",
    "x_train = torch.empty((N, x_d0*x_d1*x_d2), device=device)\n",
    "y_train = torch.empty(N, dtype=torch.long)\n",
    "for i in range(N): \n",
    "     x_train[i,:] = torch.reshape(cifar_trainset[i][0], (1, x_d0*x_d1*x_d2))\n",
    "     y_train[i] = cifar_trainset[i][1]\n",
    "x_train = torch.t(x_train)\n",
    "y_one_hot = torch.zeros(N, K).scatter_(1, torch.reshape(y_train, (N, 1)), 1)\n",
    "y_one_hot = torch.t(y_one_hot).to(device=device)\n",
    "y_train = y_train.to(device=device)\n",
    "\n",
    "# Manipulate test set\n",
    "N_test = x_d3_test = len(cifar_testset)\n",
    "x_test = torch.empty((N_test, x_d0*x_d1*x_d2), device=device)\n",
    "y_test = torch.empty(N_test, dtype=torch.long)\n",
    "for i in range(N_test): \n",
    "     x_test[i,:] = torch.reshape(cifar_testset[i][0], (1, x_d0*x_d1*x_d2))\n",
    "     y_test[i] = cifar_testset[i][1]\n",
    "x_test = torch.t(x_test)\n",
    "y_test_one_hot = torch.zeros(N_test, K).scatter_(1, torch.reshape(y_test, (N_test, 1)), 1)\n",
    "y_test_one_hot = torch.t(y_test_one_hot).to(device=device)\n",
    "y_test = y_test.to(device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Main algorithm (Jinshan's Algorithm in Zeng et al. (2018))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define parameter initialization and forward pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialization of parameters\n",
    "torch.manual_seed(5)\n",
    "def initialize(dim_in, dim_out):\n",
    "    W = 0.01*torch.randn(dim_out, dim_in, device=device)\n",
    "    b = 0.1*torch.ones(dim_out, 1, device=device)\n",
    "    return W, b\n",
    "\n",
    "# Forward pass\n",
    "def feed_forward(weight, bias, activation, dim = N):\n",
    "    U = torch.addmm(bias.repeat(1, dim), weight, activation)\n",
    "    V = nn.ReLU()(U)\n",
    "    return U, V"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define functions for updating blocks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def updateV_js(U1,U2,W,b,rho,gamma): \n",
    "    _, d = W.size()\n",
    "    I = torch.eye(d, device=device)\n",
    "    U1 = nn.ReLU()(U1)\n",
    "    _, col_U2 = U2.size()\n",
    "    Vstar = torch.mm(torch.inverse(rho*(torch.mm(torch.t(W),W)) + gamma*I), \\\n",
    "                     rho*torch.mm(torch.t(W),U2-b.repeat(1,col_U2)) + gamma*U1)\n",
    "    return Vstar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def updateWb_js(U, V, W, b, alpha, rho): \n",
    "    d,N = V.size()\n",
    "    I = torch.eye(d, device=device)\n",
    "    _, col_U = U.size()\n",
    "    Wstar = torch.mm(alpha*W + rho*torch.mm(U - b.repeat(1,col_U),torch.t(V)),\\\n",
    "                     torch.inverse(alpha*I + rho*(torch.mm(V,torch.t(V)))))\n",
    "    bstar = (alpha*b+rho*torch.sum(U-torch.mm(W,V), dim=1).reshape(b.size()))/(rho*N + alpha)\n",
    "    return Wstar, bstar"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the proximal operator of the ReLU activation function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def relu_prox(a, b, gamma, d, N):\n",
    "    val = torch.empty(d,N, device=device)\n",
    "    x = (a+gamma*b)/(1+gamma)\n",
    "    y = torch.min(b,torch.zeros(d,N, device=device))\n",
    "\n",
    "    val = torch.where(a+gamma*b < 0, y, torch.zeros(d,N, device=device))\n",
    "    val = torch.where(((a+gamma*b >= 0) & (b >=0)) | ((a*(gamma-np.sqrt(gamma*(gamma+1))) <= gamma*b) & (b < 0)), x, val)\n",
    "    val = torch.where((-a <= gamma*b) & (gamma*b <= a*(gamma-np.sqrt(gamma*(gamma+1)))), b, val)\n",
    "    return val"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "gamma = 1e1\n",
    "# gamma1 = gamma2 = gamma3 = gamma4 = gamma5 = gamma6 \\\n",
    "# = gamma7 = gamma8 = gamma9 = gamma10 = gamma11 = gamma\n",
    "\n",
    "rho = 1e1\n",
    "# rho1 = rho2 = rho3 = rho4 = rho5 = rho6 = rho7 = rho8 \n",
    "# = rho9 = rho10 = rho11 = rho \n",
    "\n",
    "\n",
    "alpha = 1\n",
    "# alpha1 = alpha2 = alpha3 = alpha4 = alpha5 = alpha6 = alpha7 \\\n",
    "# = alpha8 = alpha9 = alpha10 = alpha"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define block update"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def block_update(Wn, bn, Wn_1, bn_1, Vn, Un, Vn_1, Un_1, Vn_2, dn_1, alpha = alpha, gamma = gamma, rho = rho, dim = N):\n",
    "    # update W(n) and b(n)\n",
    "    Wn, bn = updateWb_js(Un, Vn_1, Wn, bn, alpha, rho)\n",
    "    # update V(n-1)\n",
    "    Vn_1 = updateV_js(Un_1, Un, Wn, bn, rho, gamma)\n",
    "    # update U(n-1)\n",
    "    Un_1 = relu_prox(Vn_1, (rho*torch.addmm(bn_1.repeat(1,dim), Wn_1, Vn_2) + \\\n",
    "                            alpha*Un_1)/(rho + alpha), (rho + alpha)/gamma, dn_1, dim)\n",
    "    return Wn, bn, Vn_1, Un_1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define loss computation of layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_loss(weight, bias, activation, preactivation, rho = rho):\n",
    "    loss = rho/2*torch.pow(torch.dist(torch.addmm(bias.repeat(1,N), \\\n",
    "                                                  weight, activation), preactivation, 2), 2).cpu().numpy()\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parameter initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layers: input + 3 hidden + output\n",
    "d0 = x_d0*x_d1*x_d2\n",
    "d1 = d2 = d3 = d4 = d5 = d6 \\\n",
    "= d7 = d8 = d9 = d10 = 700\n",
    "d11 = K \n",
    "\n",
    "\n",
    "W1, b1 = initialize(d0, d1)\n",
    "W2, b2 = initialize(d1, d2)\n",
    "W3, b3 = initialize(d2, d3)\n",
    "W4, b4 = initialize(d3, d4)\n",
    "W5, b5 = initialize(d4, d5)\n",
    "W6, b6 = initialize(d5, d6)\n",
    "W7, b7 = initialize(d6, d7)\n",
    "W8, b8 = initialize(d7, d8)\n",
    "W9, b9 = initialize(d8, d9)\n",
    "W10, b10 = initialize(d9, d10)\n",
    "W11, b11 = initialize(d10, d11)\n",
    "\n",
    "\n",
    "U1, V1 = feed_forward(W1, b1, x_train)\n",
    "U2, V2 = feed_forward(W2, b2, V1)\n",
    "U3, V3 = feed_forward(W3, b3, V2)\n",
    "# U4 = torch.addmm(b4.repeat(1, N), W4, V3)\n",
    "# V4 = U4\n",
    "U4, V4 = feed_forward(W4, b4, V3)\n",
    "U5, V5 = feed_forward(W5, b5, V4)\n",
    "U6, V6 = feed_forward(W6, b6, V5)\n",
    "U7, V7 = feed_forward(W7, b7, V6)\n",
    "U8, V8 = feed_forward(W8, b8, V7)\n",
    "U9, V9 = feed_forward(W9, b9, V8)\n",
    "U10, V10 = feed_forward(W10, b10, V9)\n",
    "U11 = torch.addmm(b11.repeat(1, N), W11, V10)\n",
    "V11 = U11\n",
    "\n",
    "niter = 300\n",
    "loss1 = np.empty(niter)\n",
    "loss2 = np.empty(niter)\n",
    "# layer1 = np.empty(niter)\n",
    "# layer2 = np.empty(niter)\n",
    "# layer3 = np.empty(niter)\n",
    "# layer4 = np.empty(niter)\n",
    "# layer11 = np.empty(niter)\n",
    "# layer21 = np.empty(niter)\n",
    "# layer31 = np.empty(niter)\n",
    "# layer41 = np.empty(niter)\n",
    "accuracy_train = np.empty(niter)\n",
    "accuracy_test = np.empty(niter)\n",
    "time1 = np.empty(niter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 50000 samples, validate on 10000 samples\n",
      "Epoch 1 / 300 \n",
      " - time (s): 2.424466609954834 - sq_loss: 190972.12890625 - tot_loss: 362263826562.10144 - acc: 0.1 - val_acc: 0.1\n",
      "Epoch 2 / 300 \n",
      " - time (s): 2.2526965141296387 - sq_loss: 174005.60546875 - tot_loss: 5351093043376144.0 - acc: 0.0991 - val_acc: 0.0991\n",
      "Epoch 3 / 300 \n",
      " - time (s): 2.239732265472412 - sq_loss: 161939.609375 - tot_loss: 1.4645934686857762e+20 - acc: 0.1011 - val_acc: 0.1031\n",
      "Epoch 4 / 300 \n",
      " - time (s): 2.2757959365844727 - sq_loss: 148603.369140625 - tot_loss: 1.2197390704144783e+25 - acc: 0.10018 - val_acc: 0.0967\n",
      "Epoch 5 / 300 \n",
      " - time (s): 2.2521042823791504 - sq_loss: 136146.58203125 - tot_loss: 2.1870806370353496e+29 - acc: 0.09992 - val_acc: 0.1005\n",
      "Epoch 6 / 300 \n",
      " - time (s): 2.2903716564178467 - sq_loss: 124643.33984375 - tot_loss: 1.8361413345412208e+34 - acc: 0.10022 - val_acc: 0.1002\n",
      "Epoch 7 / 300 \n",
      " - time (s): 2.2813639640808105 - sq_loss: 114074.453125 - tot_loss: 7.42246544665113e+38 - acc: 0.10008 - val_acc: 0.0966\n",
      "Epoch 8 / 300 \n",
      " - time (s): 2.276777744293213 - sq_loss: 104384.9609375 - tot_loss: inf - acc: 0.09868 - val_acc: 0.1055\n",
      "Epoch 9 / 300 \n",
      " - time (s): 2.2785563468933105 - sq_loss: 95508.662109375 - tot_loss: inf - acc: 0.09998 - val_acc: 0.1001\n",
      "Epoch 10 / 300 \n",
      " - time (s): 2.282823085784912 - sq_loss: 87377.978515625 - tot_loss: inf - acc: 0.10002 - val_acc: 0.0999\n",
      "Epoch 11 / 300 \n",
      " - time (s): 2.2888951301574707 - sq_loss: 79935.0927734375 - tot_loss: inf - acc: 0.09966 - val_acc: 0.0993\n",
      "Epoch 12 / 300 \n",
      " - time (s): 2.296295642852783 - sq_loss: 73124.8095703125 - tot_loss: nan - acc: 0.1 - val_acc: 0.1\n",
      "Epoch 13 / 300 \n",
      " - time (s): 2.3004822731018066 - sq_loss: 66890.107421875 - tot_loss: nan - acc: 0.1 - val_acc: 0.1\n",
      "Epoch 14 / 300 \n",
      " - time (s): 2.427736282348633 - sq_loss: 61192.568359375 - tot_loss: nan - acc: 0.1 - val_acc: 0.1\n",
      "Epoch 15 / 300 \n",
      " - time (s): 2.4458000659942627 - sq_loss: 55980.5517578125 - tot_loss: nan - acc: 0.1 - val_acc: 0.1\n",
      "Epoch 16 / 300 \n",
      " - time (s): 2.3162472248077393 - sq_loss: 51205.01953125 - tot_loss: nan - acc: 0.1 - val_acc: 0.1\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-12-731ac57a7f61>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     41\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     42\u001b[0m     \u001b[0;31m# update W3, b3, V2 and U2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m     \u001b[0mW3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mV2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mU2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mblock_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mW3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mW2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mV3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mU3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mV2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mU2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mV1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     44\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     45\u001b[0m     \u001b[0;31m# update W2, b2, V1 and U1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-9-a3c6f393bbdb>\u001b[0m in \u001b[0;36mblock_update\u001b[0;34m(Wn, bn, Wn_1, bn_1, Vn, Un, Vn_1, Un_1, Vn_2, dn_1, alpha, gamma, rho, dim)\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mblock_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mWn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mWn_1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbn_1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mVn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mUn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mVn_1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mUn_1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mVn_2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdn_1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0malpha\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malpha\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgamma\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgamma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrho\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrho\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mN\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m     \u001b[0;31m# update W(n) and b(n)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m     \u001b[0mWn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mupdateWb_js\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mUn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mVn_1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mWn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0malpha\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrho\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m     \u001b[0;31m# update V(n-1)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0mVn_1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mupdateV_js\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mUn_1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mUn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mWn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrho\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgamma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-6-347ff8d393e8>\u001b[0m in \u001b[0;36mupdateWb_js\u001b[0;34m(U, V, W, b, alpha, rho)\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0mI\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meye\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m     \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcol_U\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mU\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m     \u001b[0mWstar\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malpha\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mW\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mrho\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mU\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrepeat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcol_U\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mV\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m                     \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minverse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malpha\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mI\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mrho\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mV\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mV\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      6\u001b[0m     \u001b[0mbstar\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0malpha\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mrho\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mU\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mW\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mV\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrho\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mN\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0malpha\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mWstar\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbstar\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Iterations\n",
    "print('Train on', N, 'samples, validate on', N_test, 'samples')\n",
    "for k in range(niter):\n",
    "    start = time.time()\n",
    "    '''\n",
    "    # update V4\n",
    "    V4 = (y_one_hot + gamma*U4 + alpha*V4)/(1 + gamma + alpha)\n",
    "    \n",
    "    # update U4 \n",
    "    U4 = (gamma*V4 + rho*(torch.mm(W4,V3) + b4.repeat(1,N)))/(gamma + rho)\n",
    "    '''\n",
    "    # update V11\n",
    "    V11 = (y_one_hot + gamma*U11 + alpha*V11)/(1 + gamma + alpha)\n",
    "    \n",
    "    # update U11\n",
    "    U11 = (gamma*V11 + rho*(torch.mm(W11, V10) + b11.repeat(1,N)))/(gamma + rho)\n",
    "    \n",
    "    # update W11, b11, V10 and U10\n",
    "    W11, b11, V10, U10 = block_update(W11, b11, W10, b10, V11, U11, V10, U10, V9, d10)\n",
    "    \n",
    "    # update W10, b10, V9 and U9\n",
    "    W10, b10, V9, U9 = block_update(W10, b10, W9, b9, V10, U10, V9, U9, V8, d9)\n",
    "    \n",
    "    # update W9, b9, V8 and U8\n",
    "    W9, b9, V8, U8 = block_update(W9, b9, W8, b8, V9, U9, V8, U8, V7, d8)\n",
    "    \n",
    "    # update W8, b8, V7 and U7\n",
    "    W8, b8, V7, U7 = block_update(W8, b8, W7, b7, V8, U8, V7, U7, V6, d7)\n",
    "    \n",
    "    # update W7, b7, V6 and U6\n",
    "    W7, b7, V6, U6 = block_update(W7, b7, W6, b6, V7, U7, V6, U6, V5, d6)\n",
    "    \n",
    "    # update W6, b6, V5 and U5\n",
    "    W6, b6, V5, U5 = block_update(W6, b6, W5, b5, V6, U6, V5, U5, V4, d5)\n",
    "    \n",
    "    # update W5, b5, V4 and U4\n",
    "    W5, b5, V4, U4 = block_update(W5, b5, W4, b4, V5, U5, V4, U4, V3, d4)\n",
    "    \n",
    "    # update W4, b4, V3 and U3\n",
    "    W4, b4, V3, U3 = block_update(W4, b4, W3, b3, V4, U4, V3, U3, V2, d3)\n",
    "    \n",
    "    # update W3, b3, V2 and U2\n",
    "    W3, b3, V2, U2 = block_update(W3, b3, W2, b2, V3, U3, V2, U2, V1, d2)\n",
    "    \n",
    "    # update W2, b2, V1 and U1\n",
    "    W2, b2, V1, U1 = block_update(W2, b2, W1, b1, V2, U2, V1, U1, x_train, d1)\n",
    "    \n",
    "    # update W1 and b1\n",
    "    W1, b1 = updateWb_js(U1, x_train, W1, b1, alpha, rho)\n",
    "\n",
    "    # compute updated training activations\n",
    "    _, a1_train = feed_forward(W1, b1, x_train)\n",
    "    _, a2_train = feed_forward(W2, b2, a1_train)\n",
    "    _, a3_train = feed_forward(W3, b3, a2_train)\n",
    "    _, a4_train = feed_forward(W4, b4, a3_train)\n",
    "    _, a5_train = feed_forward(W5, b5, a4_train)\n",
    "    _, a6_train = feed_forward(W6, b6, a5_train)\n",
    "    _, a7_train = feed_forward(W7, b7, a6_train)\n",
    "    _, a8_train = feed_forward(W8, b8, a7_train)\n",
    "    _, a9_train = feed_forward(W9, b9, a8_train)\n",
    "    _, a10_train = feed_forward(W10, b10, a9_train)\n",
    "    \n",
    "    \n",
    "    # training prediction\n",
    "    pred = torch.argmax(torch.addmm(b11.repeat(1, N), W11, a10_train), dim=0)\n",
    "    # pred = torch.argmax(torch.addmm(b4.repeat(1, N), W4, a3_train), dim=0)\n",
    "    \n",
    "    # compute test activations\n",
    "    _, a1_test = feed_forward(W1, b1, x_test, N_test)\n",
    "    _, a2_test = feed_forward(W2, b2, a1_test, N_test)\n",
    "    _, a3_test = feed_forward(W3, b3, a2_test, N_test)\n",
    "    _, a4_test = feed_forward(W4, b4, a3_test, N_test)\n",
    "    _, a5_test = feed_forward(W5, b5, a4_test, N_test)\n",
    "    _, a6_test = feed_forward(W6, b6, a5_test, N_test)\n",
    "    _, a7_test = feed_forward(W7, b7, a6_test, N_test)\n",
    "    _, a8_test = feed_forward(W8, b8, a7_test, N_test)\n",
    "    _, a9_test = feed_forward(W9, b9, a8_test, N_test)\n",
    "    _, a10_test = feed_forward(W10, b10, a9_test, N_test)\n",
    "    \n",
    "    # test/validation prediction\n",
    "    pred_test = torch.argmax(torch.addmm(b11.repeat(1, N_test), W11, a10_test), dim=0)\n",
    "    # pred_test = torch.argmax(torch.addmm(b4.repeat(1, N_test), W4, a3_test), dim=0)\n",
    "    \n",
    "    # compute training loss\n",
    "    loss1[k] = gamma/2*torch.pow(torch.dist(V11,y_one_hot,2),2).cpu().numpy()\n",
    "    # loss1[k] = gamma/2*torch.pow(torch.dist(V4,y_one_hot,2),2).cpu().numpy()\n",
    "    loss2[k] = loss1[k] \\\n",
    "    + compute_loss(W1, b1, x_train, U1) \\\n",
    "    + compute_loss(W2, b2, V1, U2) \\\n",
    "    + compute_loss(W3, b3, V2, U3) \\\n",
    "    + compute_loss(W4, b4, V3, U4) \\\n",
    "    + compute_loss(W5, b5, V4, U5) \\\n",
    "    + compute_loss(W6, b6, V5, U6) \\\n",
    "    + compute_loss(W7, b7, V6, U7) \\\n",
    "    + compute_loss(W8, b8, V7, U8) \\\n",
    "    + compute_loss(W9, b9, V8, U9) \\\n",
    "    + compute_loss(W10, b10, V9, U10) \\\n",
    "    + compute_loss(W11, b11, V10, U11) \n",
    "    \n",
    "    # compute training accuracy\n",
    "    correct_train = pred == y_train\n",
    "    accuracy_train[k] = np.mean(correct_train.cpu().numpy())\n",
    "    \n",
    "    # compute validation accuracy\n",
    "    correct_test = pred_test == y_test\n",
    "    accuracy_test[k] = np.mean(correct_test.cpu().numpy())\n",
    "    \n",
    "    # compute training time\n",
    "    stop = time.time()\n",
    "    duration = stop - start\n",
    "    time1[k] = duration\n",
    "    \n",
    "    # print results\n",
    "    print('Epoch', k + 1, '/', niter, '\\n', \n",
    "          '-', 'time (s):', time1[k], '-', 'sq_loss:', loss1[k], '-', 'tot_loss:', loss2[k], \n",
    "          '-', 'acc:', accuracy_train[k], '-', 'val_acc:', accuracy_test[k])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualization of training results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "plt.plot(np.arange(1, niter + 1), loss2)\n",
    "plt.title('training loss')\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(np.arange(1, niter + 1), accuracy_test)\n",
    "plt.title('validation accuracy')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
